import numpy as np
import pandas as pd
import re
from scipy.optimize import minimize
from tqdm import tqdm
from joblib import Parallel, delayed
import logging
import time
import matplotlib.pyplot as plt
import os

# Module-wide logging: records at INFO level and above go to symbolic_fit.log
# in the working directory, each prefixed with a timestamp and level name.
# NOTE: this runs at import time, so importing the module configures logging.
logging.basicConfig(filename='symbolic_fit.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# The 168 primes from 2 through 997, in ascending order.  D() picks one entry
# per evaluation by indexing this table modulo its length.
PRIMES = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
    73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151,
    157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233,
    239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317,
    331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419,
    421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503,
    509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607,
    613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701,
    709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811,
    821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911,
    919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997
]

# Golden ratio, (1 + sqrt(5)) / 2; used by fib_real() and as a factor in D().
phi = (1 + np.sqrt(5)) / 2
# Memo table for fib_real(): maps an index n (possibly non-integer) to its value.
# It is never pruned, so it grows with the number of distinct n requested.
fib_cache = {}

def fib_real(n):
    """Return a real-valued Fibonacci interpolation at index ``n``.

    The value is ``phi**n / sqrt(5) - (1/phi)**n * cos(pi * n)``, i.e. a
    Binet-style closed form whose alternating sign is replaced by a cosine so
    that non-integer ``n`` is accepted.

    Results are memoised in the module-level ``fib_cache``.  An uncached
    ``n`` greater than 100 short-circuits to ``0.0`` and is not stored.
    """
    try:
        return fib_cache[n]
    except KeyError:
        pass

    # Indices beyond 100 are deliberately not evaluated.
    if n > 100:
        return 0.0

    growth = phi**n / np.sqrt(5)
    oscillation = ((1/phi)**n) * np.cos(np.pi * n)
    value = growth - oscillation
    fib_cache[n] = value
    return value

def D(n, beta, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0):
    """Evaluate the symbolic model at position ``n + beta``.

    The product ``scale * phi * |F| * base**(n + beta) * P * Omega`` is
    assembled in log space, where ``F = fib_real(n + beta)`` and ``P`` is the
    entry of ``PRIMES`` at index ``floor(n + beta)`` modulo the table length.
    The result is the square root of that product times ``r ** k``, carrying
    the sign of ``F``.

    Returns:
        The model value, or ``None`` when the power term's logarithm falls
        outside [-500, 500], the accumulated logarithm is not finite, or any
        exception is raised (the exception is logged).
    """
    try:
        position = n + beta
        fib_term = fib_real(position)

        prime_count = len(PRIMES)
        prime_term = PRIMES[int(np.floor(position) + prime_count) % prime_count]

        # Work with logarithms so that base**(n + beta) cannot overflow.
        log_power = position * np.log(max(base, 1e-10))
        if abs(log_power) > 500:
            return None

        # Non-positive factors are clamped before taking their logarithm.
        log_total = (
            np.log(max(scale, 1e-30))
            + np.log(phi)
            + np.log(max(abs(fib_term), 1e-30))
            + log_power
            + np.log(prime_term)
            + np.log(max(Omega, 1e-30))
        )
        if n > 1000:
            log_total += np.log(np.log(n) / np.log(1000))
        if not np.isfinite(log_total):
            return None

        # Re-attach the sign lost by taking |F| above.
        signed = np.exp(log_total) * np.sign(fib_term)
        magnitude = np.sqrt(max(abs(signed), 1e-30))
        return magnitude * (r ** k) * np.sign(signed)
    except Exception as e:
        logging.error(f"D failed: n={n}, beta={beta}, error={e}")
        return None

def invert_D(value, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0, max_n=500):
    """Grid-search for ``D`` parameters that reproduce ``value``.

    The search covers 20 evenly spaced ``n`` in ``[0, n_limit]``, three
    ``beta`` values in ``[0, 1]``, five log-spaced scale multipliers centred
    on the magnitude of ``value``, and ``r``/``k`` each taken from
    ``{0.5, 1.0}``.  Only combinations within 10% relative error are kept.

    Args:
        value: Target value to reproduce.
        r, k: Accepted for interface compatibility; the search uses its own
            ``{0.5, 1.0}`` grid for both and does not read these arguments.
        Omega, base, scale: Forwarded to ``D`` (``scale`` is multiplied by
            the searched scale factor).
        max_n: Upper cap on the largest ``n`` searched.  The effective limit
            is ``min(max_n, max(200, int(100 * |log10|value||)))``.

    Returns:
        ``(n, beta, dynamic_scale, emergent_uncertainty, r_local, k_local)``
        for the closest candidate, or ``None`` if nothing is within 10%, the
        uncertainty is not finite, or an exception occurs (logged).
        ``emergent_uncertainty`` is the standard deviation of the model values
        of the five closest candidates, or 1% of the single value when only
        one candidate exists.
    """
    log_val = np.log10(max(abs(value), 1e-30))
    # Widen the n range for values far from unity, but honour the caller's cap
    # (previously the parameter was overwritten with a hard-coded 500).
    max_n = min(max_n, max(200, int(100 * abs(log_val))))
    n_values = np.linspace(0, max_n, 20)  # Reduced search space
    beta_values = np.linspace(0, 1, 3)
    scale_factors = np.logspace(max(log_val - 2, -10), min(log_val + 2, 10), num=5)
    candidates = []
    try:
        for n in tqdm(n_values, desc=f"invert_D for {value:.2e}", leave=False):
            for beta in beta_values:
                for dynamic_scale in scale_factors:
                    for r_local in (0.5, 1.0):
                        for k_local in (0.5, 1.0):
                            val = D(n, beta, r_local, k_local, Omega, base, scale * dynamic_scale)
                            if val is None or not np.isfinite(val):
                                continue
                            diff = abs(val - value)
                            rel_diff = diff / max(abs(value), 1e-30)
                            if rel_diff < 0.1:  # Only keep candidates with low relative error
                                # Keep the model value too, so it need not be recomputed below.
                                candidates.append((diff, n, beta, dynamic_scale, r_local, k_local, val))
        if not candidates:
            logging.error(f"invert_D: No valid candidates for value {value}")
            return None

        # Stable sort on the absolute difference only; keep the five closest.
        candidates.sort(key=lambda cand: cand[0])
        closest = candidates[:5]
        model_vals = [cand[6] for cand in closest]
        if len(model_vals) > 1:
            emergent_uncertainty = np.std(model_vals)
        else:
            emergent_uncertainty = abs(model_vals[0]) * 0.01
        if not np.isfinite(emergent_uncertainty):
            logging.error(f"invert_D: Non-finite emergent uncertainty for value {value}")
            return None

        _, best_n, best_beta, best_scale, best_r, best_k, _ = closest[0]
        return best_n, best_beta, best_scale, emergent_uncertainty, best_r, best_k
    except Exception as e:
        logging.error(f"invert_D failed for value {value}: {e}")
        return None

def parse_codata_ascii(filename):
    """Parse a CODATA ``allascii.txt``-style table into a DataFrame.

    Each data row is expected to hold a name, a value, an uncertainty and an
    optional unit, with the name separated from the value by two or more
    spaces.  Numbers may have their digits grouped by single spaces
    (``6.644 657 3450 e-27``, ``299 792 458``), may carry a trailing ``...``
    and an exponent, and the uncertainty may be ``exact`` or ``(exact)``.

    Blank lines, lines starting with ``-`` and the ``Quantity`` header are
    skipped, as are lines that do not match the row pattern.  Rows whose
    numbers cannot be converted are logged as warnings and skipped.

    Args:
        filename: Path of the text file to read.

    Returns:
        DataFrame with columns ``name``, ``value``, ``uncertainty`` and
        ``unit``.  Exact constants get uncertainty ``0.0``; rows without a
        unit get ``""``.  The columns are present even when no row parses.
    """
    # Optional sign, space-grouped integer digits, optional space-grouped
    # fraction, optional "..." truncation marker, optional exponent.
    number = (
        r"-?\d+(?: \d+)*(?:\.\d*(?: \d+)*)?"
        r"(?:\.\.\.)?(?:\s*[Ee][+-]?\d+)?(?:\.\.\.)?"
    )
    pattern = re.compile(
        r"^\s*(.*?)\s{2,}(" + number + r")\s+(" + number + r"|\(?exact\)?)"
        r"(?:\s+(\S.*?))?\s*$"
    )
    columns = ["name", "value", "uncertainty", "unit"]
    constants = []
    with open(filename, "r") as f:
        for line in f:
            if not line.strip() or line.startswith("-") or line.lstrip().startswith("Quantity"):
                continue
            m = pattern.match(line)
            if not m:
                continue
            name, value_str, uncert_str, unit = m.groups()
            try:
                value = float(value_str.replace("...", "").replace(" ", ""))
                if "exact" in uncert_str:
                    uncertainty = 0.0
                else:
                    uncertainty = float(uncert_str.replace("...", "").replace(" ", ""))
            except ValueError as e:
                logging.warning(f"Failed parsing line: {line.strip()} - {e}")
                continue
            constants.append({
                "name": name.strip(),
                "value": value,
                "uncertainty": uncertainty,
                # The unit group is None for dimensionless constants.
                "unit": (unit or "").strip(),
            })
    return pd.DataFrame(constants, columns=columns)

def check_physical_consistency(df_results):
    """Flag constants whose fitted parameters violate hard-coded relations.

    Each relation is a tuple of two or three constant names followed by a
    check function, a threshold and a reason.  The check receives the first
    matching result row for each name and returns a deviation; a deviation
    above the threshold marks the relation as violated.

    A relation is skipped when any of its constants is absent from
    ``df_results``, when a column it reads is missing, or when the fitted
    values are ``None`` (failed fit).  A NaN deviation never exceeds the
    threshold and therefore flags nothing.

    Args:
        df_results: Fit results with at least a ``name`` column; the checks
            read ``scale``, ``n`` and ``value``.

    Returns:
        List of ``(name, reason)`` tuples.  Two-name relations flag both
        constants; three-name relations flag only the third.
    """
    relations = [
        ('Planck constant', 'reduced Planck constant', lambda x, y: abs(x['scale'] / y['scale'] - 2 * np.pi), 0.1, 'scale ratio vs. 2π'),
        ('proton mass', 'proton-electron mass ratio', lambda x, y: abs(x['n'] - y['n'] - np.log10(1836)), 0.5, 'n difference vs. log(proton-electron ratio)'),
        ('Fermi coupling constant', 'weak mixing angle', lambda x, y: abs(x['scale'] - y['scale'] / np.sqrt(2)), 0.1, 'scale vs. sin²θ_W/√2'),
        ('tau energy equivalent', 'tau mass energy equivalent in MeV', lambda x, y: abs(x['value'] - y['value']), 0.01, 'value consistency'),
        ('proton mass', 'electron mass', 'proton-electron mass ratio',
         lambda x, y, z: abs(z['n'] - abs(x['n'] - y['n'])), 10.0, 'n inconsistency for mass ratio')
    ]
    bad_data = []
    # The last three items are always (check, threshold, reason); whatever
    # precedes them is the list of constant names, so the arity of a relation
    # never has to be inferred from the tuple length.
    for *names, check_func, threshold, reason in relations:
        matches = [df_results[df_results['name'] == name] for name in names]
        if any(match.empty for match in matches):
            continue  # relation does not apply to this result set
        rows = [match.iloc[0] for match in matches]
        try:
            deviation = check_func(*rows)
        except (KeyError, TypeError):
            # Missing column, or None parameters from a failed fit.
            continue
        if not deviation > threshold:
            continue
        flagged = names if len(names) == 2 else names[-1:]
        bad_data.extend((name, f"Physical inconsistency: {reason}") for name in flagged)
    return bad_data

def total_error(params, df_subset):
    """Objective function: mean absolute fit error over ``df_subset``.

    Args:
        params: Sequence ``(r, k, Omega, base, scale)`` of model parameters.
        df_subset: Constants to fit, as accepted by
            ``symbolic_fit_all_constants``.

    Returns:
        The mean of the ``error`` column of the fit results, or ``np.inf``
        when there are no results or the mean is not finite.
    """
    r, k, Omega, base, scale = params
    fitted = symbolic_fit_all_constants(df_subset, base=base, Omega=Omega, r=r, k=k, scale=scale)
    if not fitted.empty:
        mean_error = fitted['error'].mean()
        if np.isfinite(mean_error):
            return mean_error
    return np.inf

def _failed_fit_record(name, value, unit, reason):
    """Build the result row used when no usable fit exists for a constant."""
    return {
        'name': name, 'value': value, 'unit': unit, 'n': None, 'beta': None, 'approx': None,
        'error': None, 'rel_error': None, 'uncertainty': None, 'scale': None, 'bad_data': True,
        'bad_data_reason': reason
    }


def process_constant(row, r, k, Omega, base, scale):
    """Fit a single constant and return its result record.

    The magnitude of the constant is inverted with ``invert_D``; the best
    candidate is re-evaluated with ``D`` and given the sign of the original
    value.

    Args:
        row: Mapping with keys ``name``, ``value``, ``uncertainty`` and
            ``unit``.
        r, k, Omega, base, scale: Model parameters forwarded to ``invert_D``
            and ``D``.

    Returns:
        A dict with keys ``name``, ``value``, ``unit``, ``n``, ``beta``,
        ``approx``, ``error``, ``rel_error``, ``uncertainty`` (the emergent
        uncertainty from ``invert_D``, not the reported one), ``scale``,
        ``bad_data`` and ``bad_data_reason``.  When no fit is found or an
        exception occurs, the fit fields are ``None`` and ``bad_data`` is
        ``True``.  This function never raises.
    """
    # Read the identifying fields leniently up front so the error path below
    # can always build a record, even when the row itself is malformed.
    name, value, unit = row.get('name'), row.get('value'), row.get('unit')
    try:
        uncertainty = row['uncertainty']
        abs_value = abs(value)
        sign = np.sign(value)
        result = invert_D(abs_value, r=r, k=k, Omega=Omega, base=base, scale=scale)
        if result is None:
            return _failed_fit_record(name, value, unit, 'No valid fit found')
        n, beta, dynamic_scale, emergent_uncertainty, r_local, k_local = result
        approx = D(n, beta, r_local, k_local, Omega, base, scale * dynamic_scale)
        if approx is None:
            return _failed_fit_record(name, value, unit, 'D function returned None')
        approx *= sign  # Apply original sign
        error = abs(approx - value)
        rel_error = error / max(abs(value), 1e-30)
        bad_data = False
        bad_data_reason = ""
        if rel_error > 0.5:
            bad_data = True
            bad_data_reason += f"High relative uncertainty ({rel_error:.2e} > 0.5); "
        # NOTE(review): for exact constants (uncertainty == 0) this flags every
        # fit with a positive emergent uncertainty -- confirm that is intended.
        if emergent_uncertainty > uncertainty * 10 or emergent_uncertainty < uncertainty / 10:
            bad_data = True
            bad_data_reason += f"Uncertainty deviates from emergent ({emergent_uncertainty:.2e} vs. {uncertainty:.2e}); "
        return {
            'name': name, 'value': value, 'unit': unit, 'n': n, 'beta': beta, 'approx': approx,
            'error': error, 'rel_error': rel_error, 'uncertainty': emergent_uncertainty,
            'scale': scale * dynamic_scale, 'bad_data': bad_data, 'bad_data_reason': bad_data_reason
        }
    except Exception as e:
        logging.error(f"process_constant failed for {name}: {e}")
        return _failed_fit_record(name, value, unit, f"Processing error: {str(e)}")

def symbolic_fit_all_constants(df, base=2, Omega=1.0, r=1.0, k=1.0, scale=1.0):
    """Fit every constant in ``df`` in parallel and flag suspicious results.

    Each row is handed to ``process_constant`` in a joblib worker.  The
    ``bad_data`` / ``bad_data_reason`` values it returns are kept, and two
    further checks add to them:

    * an IQR outlier test on the ``uncertainty`` values of rows sharing a
      name, and
    * the relations evaluated by ``check_physical_consistency``.

    ``process_constant`` already flags high relative error and a mismatch
    between emergent and reported uncertainty per row, so those checks are
    not repeated here; the result rows carry no separate
    ``emergent_uncertainty`` column to compare against.

    Args:
        df: DataFrame of constants with ``name``, ``value``, ``uncertainty``
            and ``unit`` columns.
        base, Omega, r, k, scale: Model parameters forwarded to
            ``process_constant``.

    Returns:
        DataFrame with one row per fitted constant; empty if ``df`` is empty.
    """
    logging.info("Starting symbolic fit for all constants...")
    # ``maxtasksperchild`` is a multiprocessing.Pool option, not a documented
    # joblib.Parallel parameter, so it is not passed here.
    results = Parallel(n_jobs=-1, timeout=15, backend='loky')(
        delayed(process_constant)(row, r, k, Omega, base, scale)
        for row in tqdm(df.to_dict('records'), desc="Fitting constants")
    )
    df_results = pd.DataFrame([res for res in results if res is not None])

    if not df_results.empty:
        # Outlier test on the uncertainties within each group of equal names.
        for name in df_results['name'].unique():
            uncertainties = df_results.loc[df_results['name'] == name, 'uncertainty'].dropna()
            if uncertainties.empty:
                continue
            uncertainties = uncertainties.astype(float)
            q1, q3 = np.percentile(uncertainties, [25, 75])
            iqr = q3 - q1
            is_outlier = (uncertainties < q1 - 1.5 * iqr) | (uncertainties > q3 + 1.5 * iqr)
            outlier_rows = uncertainties.index[is_outlier]
            if len(outlier_rows):
                df_results.loc[outlier_rows, 'bad_data'] = True
                df_results.loc[outlier_rows, 'bad_data_reason'] += 'Uncertainty outlier; '

        # Cross-constant relations; the reason is appended to any existing one.
        for name, reason in check_physical_consistency(df_results):
            hit = df_results['name'] == name
            df_results.loc[hit, 'bad_data'] = True
            df_results.loc[hit, 'bad_data_reason'] += reason + '; '

    logging.info("Symbolic fit completed.")
    return df_results

def main():
    """Run the whole pipeline: parse, optimise, fit, save, print and plot.

    Steps:
        1. Parse ``allascii.txt`` from the current directory.
        2. Optimise ``(r, k, Omega, base, scale)`` with L-BFGS-B on a fixed
           list of constants (or the first 50 if none of them were parsed).
        3. Fit all constants with the optimised parameters.
        4. Write the results to ``symbolic_fit_results_emergent_fixed.txt``
           (CSV) and ``symbolic_fit_results.txt`` (tab-separated, sorted by
           error), print summaries and show three diagnostic plots.

    The function returns early, after logging an error, when nothing could
    be parsed, the optimisation raises, or the final fit is empty.

    Raises:
        FileNotFoundError: If ``allascii.txt`` does not exist.
    """
    start_time = time.time()
    if not os.path.exists("allascii.txt"):
        raise FileNotFoundError("allascii.txt not found in the current directory")
    df = parse_codata_ascii("allascii.txt")
    logging.info(f"Parsed {len(df)} constants")
    if df.empty:
        logging.error("No constants parsed from allascii.txt")
        return

    # Optimize parameters on worst-performing constants
    worst_names = [
        'muon mag. mom. to Bohr magneton ratio', 'electron-deuteron mag. mom. ratio',
        'proton-electron mass ratio', 'neutron-electron mass ratio',
        'electron mag. mom.', 'neutron mag. mom.', 'alpha particle mass energy equivalent in MeV'
    ]
    subset_df = df[df['name'].isin(worst_names)]
    if subset_df.empty:
        subset_df = df.head(50)  # Fallback to first 50 constants
    init_params = [1.0, 1.0, 1.0, 2.0, 1.0]  # r, k, Omega, base, scale
    bounds = [(1e-5, 10), (1e-5, 10), (1e-5, 10), (1.5, 10), (1e-5, 100)]

    print("Optimizing symbolic model parameters for worst fits...")
    try:
        res = minimize(total_error, init_params, args=(subset_df,), bounds=bounds, method='L-BFGS-B', options={'maxiter': 50, 'disp': True})
        if not res.success:
            # The last iterate is still used below even if convergence failed.
            logging.warning(f"Optimization failed: {res.message}")
        r_opt, k_opt, Omega_opt, base_opt, scale_opt = res.x
        print(f"Optimization complete. Found parameters:\nr = {r_opt:.6f}, k = {k_opt:.6f}, Omega = {Omega_opt:.6f}, base = {base_opt:.6f}, scale = {scale_opt:.6f}")
    except Exception as e:
        logging.error(f"Optimization failed: {e}")
        return

    # Run final fit
    df_results = symbolic_fit_all_constants(df, base=base_opt, Omega=Omega_opt, r=r_opt, k=k_opt, scale=scale_opt)
    if df_results.empty:
        # Nothing to sort, print or plot; an empty frame has no 'error' column.
        logging.error("No results to save")
        logging.info(f"Total runtime: {time.time() - start_time:.2f} seconds")
        return
    df_results.to_csv("symbolic_fit_results_emergent_fixed.txt", index=False)
    logging.info("Saved results to symbolic_fit_results_emergent_fixed.txt")

    logging.info(f"Total runtime: {time.time() - start_time:.2f} seconds")

    # Display results
    display_columns = ['name', 'value', 'unit', 'n', 'beta', 'approx', 'error', 'uncertainty', 'scale', 'bad_data', 'bad_data_reason']
    df_results_sorted = df_results.sort_values("error")
    print("\nTop 20 best symbolic fits:")
    print(df_results_sorted.head(20)[display_columns].to_string(index=False))

    print("\nTop 20 worst symbolic fits:")
    print(df_results_sorted.tail(20)[display_columns].to_string(index=False))

    print("\nPotentially bad data constants summary:")
    # reindex() tolerates a column that is absent (e.g. when every fit failed).
    bad_data_df = df_results[df_results['bad_data'].astype(bool)].reindex(
        columns=['name', 'value', 'error', 'rel_error', 'uncertainty', 'bad_data_reason'])
    print(bad_data_df.to_string(index=False))

    df_results_sorted.to_csv("symbolic_fit_results.txt", sep="\t", index=False)

    # Plotting: failed fits have no error value, so they are left out.
    plotted = df_results_sorted.dropna(subset=['error'])

    plt.figure(figsize=(10, 5))
    plt.hist(plotted['error'], bins=50, color='skyblue', edgecolor='black')
    plt.title('Histogram of Absolute Errors in Symbolic Fit')
    plt.xlabel('Absolute Error')
    plt.ylabel('Count')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.scatter(plotted['n'], plotted['error'], alpha=0.5, s=15, c='orange', edgecolors='black')
    plt.title('Absolute Error vs Symbolic Dimension n')
    plt.xlabel('n')
    plt.ylabel('Absolute Error')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 5))
    worst_fits = plotted.tail(20)
    plt.bar(worst_fits['name'], worst_fits['error'], color='salmon', edgecolor='black')
    plt.xticks(rotation=90)
    plt.title('Absolute Errors for Top 20 Worst Symbolic Fits')
    plt.xlabel('Constant Name')
    plt.ylabel('Absolute Error')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()